
Two locus general matrix#3426

Open
lkirk wants to merge 17 commits into tskit-dev:main from lkirk:two-locus-general-matrix

Conversation

@lkirk
Contributor

lkirk commented Mar 10, 2026

Description

This is the last of the required components for the LD matrix methods. I wanted feedback on the API before adding documentation, but this method is complete and tested. The final things to do are to leak-check the CPython code and add documentation.

This feature enables a user to implement their own two-locus count statistic in Python, similar to ts.sample_count_stat. User functions take two arguments: the first is a matrix of haplotype counts and the second is a vector of sample set sizes. For instance, this is how we would implement D with this API:

def D(X, n):
    pAB, pAb, paB = X / n
    pA = pAb + pAB
    pB = paB + pAB
    return pAB - (pA * pB)
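As a quick sanity check (a made-up example, not from the PR's test suite), here is what this summary function evaluates to on a small, hypothetical haplotype count vector:

```python
import numpy as np

def D(X, n):
    # Same summary function as above: entries of X are the
    # haplotype counts, n is the sample set size.
    pAB, pAb, paB = X / n
    pA = pAb + pAB
    pB = paB + pAB
    return pAB - (pA * pB)

# Hypothetical counts for one sample set of size n = 10:
# 4 AB, 2 Ab, 2 aB haplotypes (and, implicitly, 2 ab).
d = D(np.array([4.0, 2.0, 2.0]), 10.0)
# pAB = 0.4 and pA = pB = 0.6, so d is approximately 0.04
```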

Since this API supports multiallelic sites, the user can also pass a normalisation function to control how the data is normalised across multiple alleles. The normalisation function is only run when computing over multiallelic sites. I've set the default to be $1/(n_A n_B)$, which is simply the arithmetic mean over the allele pairs at a given pair of sites. This will suffice in the majority of cases (the only outlier is $r^2$, for which there is already a Python API). We also support computing statistics between sample sets.

The user would use the above summary function like this:

ts.two_locus_count_stat(ts.samples(), D, 1)

where 1 specifies the length of the output array; the output is always required to be one-dimensional, the same as for the ts.sample_count_stat function.
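If a summary function instead returns two values per pair of sites, the final argument would be 2. This is a hedged sketch (D_and_pAB is a hypothetical function, not part of the PR), assuming the call signature described above:

```python
import numpy as np

def D_and_pAB(X, n):
    # Hypothetical two-output summary function: D plus the
    # AB haplotype frequency, returned as a length-2 vector.
    pAB, pAb, paB = X / n
    pA = pAb + pAB
    pB = paB + pAB
    return np.array([pAB - pA * pB, pAB])

# Hypothetical usage, assuming the API sketched in this PR:
#   ts.two_locus_count_stat(ts.samples(), D_and_pAB, 2)

# Pure-numpy check on example counts for one sample set (n = 10):
out = D_and_pAB(np.array([4.0, 2.0, 2.0]), 10.0)
```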

PR Checklist:

  • Tests that fully cover new/changed functionality.
  • Documentation including tutorial content if appropriate.
  • Changelogs, if there are API changes.

@codecov

codecov bot commented Mar 10, 2026

Codecov Report

❌ Patch coverage is 95.43379% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 91.96%. Comparing base (afaf3b9) to head (bd0a1a5).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3426      +/-   ##
==========================================
+ Coverage   91.92%   91.96%   +0.03%     
==========================================
  Files          37       37              
  Lines       32153    32353     +200     
  Branches     5143     5144       +1     
==========================================
+ Hits        29556    29752     +196     
- Misses       2264     2270       +6     
+ Partials      333      331       -2     
Flag Coverage Δ
C 82.71% <94.11%> (+<0.01%) ⬆️
c-python 77.68% <93.59%> (+0.34%) ⬆️
python-tests 96.40% <100.00%> (+<0.01%) ⬆️
python-tests-no-jit 33.20% <18.75%> (-0.02%) ⬇️


Components Coverage Δ
Python API 98.70% <100.00%> (+<0.01%) ⬆️
Python C interface 91.32% <94.62%> (+0.08%) ⬆️
C library 88.89% <100.00%> (+0.03%) ⬆️

out:
return ret;
}

Contributor Author

This function now serves as an inner wrapper. The general stat function accepts the summary function params so that the CPython code can pass them directly. All of the specialized stats functions call this function.

@apragsdale

Thank you for opening this, @lkirk. I'm excited to see this implemented! Do we need to do any testing to demonstrate that there are no memory leaks or anything like that, which wouldn't be included in the test suite?

I know it could be found in the tests, but it may be helpful to spell out here how the API would work for a two- or more-way stat? Would it be:

ts.two_locus_count_stat([sample_list_1, sample_list_2], two_way_stat_func, 2)

@lkirk
Contributor Author

lkirk commented Mar 14, 2026

Yes, this needs leak checking and documentation, which I can add. I mostly wanted to make sure the user interface made sense first.

I know it could be found in the tests, but it may be helpful to spell out here how the API would work for a two- or more-way stat? Would it be:

ts.two_locus_count_stat([sample_list_1, sample_list_2], two_way_stat_func, 2)

The result_dim argument to the function gives the dimensions of the summary function output. The output is required to be 1D, so result_dim really tells us the length of the returned vector. In most cases, it'll be 1. See this line in the tests. So, if your summary function returned 1 value for a pair of sites, then the function call will look like this:

ts.two_locus_count_stat([sample_list_1, sample_list_2], two_way_stat_func, 1)

two_locus_count_stat was designed to work like sample_count_stat

@petrelharp
Contributor

Hi, @lkirk! This looks pretty straightforward. To have a careful opinion about the API, I think I need to see a reasonably careful docstring? For instance, what exactly are the arguments to f and norm_f? I can probably figure it out by tracing through the code, but it'll be less error-prone if you write it down. I'll have a look through for other issues, but it will be much easier to have the description in words of what exactly it's trying to do.

@lkirk
Contributor Author

lkirk commented Mar 14, 2026

Hi @petrelharp, thanks for taking a look. I didn't offer much of a description before, does this help?

Summary Function

In the sample_count_stat API, there are 3 required parameters:

  1. sample_sets: List of lists of node ids.
  2. f: A summary function that takes a one-dimensional array of length equal to the number of sample sets and returns a one-dimensional array.
  3. output_dim: The length of the summary function's return value.

The two_locus_count_stat function (see function signature below) takes the same basic inputs except that f takes a matrix of haplotype counts and a vector of sample set sizes as input.

I've been writing summary functions with the signature f(X, n), where X is a matrix whose rows correspond to the haplotype counts ($w_{Ab}, w_{aB}, w_{AB}$) and whose columns correspond to the sample sets. Here's a sample of what that looks like:

        sample_set 1, sample_set 2
w_Ab  [[9.            9.]
w_aB   [0.            0.]
w_AB   [0.            0.]]

Why lay the data out this way? Because numpy arrays are row-major, iterating over them yields rows. That makes it easy to select the haplotype counts in one line:

AB, Ab, aB = X

Note: under the hood, the data is still laid out in an optimal way (at least as far as I can tell -- see this note)

In addition, the vector of sample set sizes is shaped so that we can normalise the counts in one go (n is a vector of sample set sizes):

pAB, pAb, paB = X / n
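To make the broadcasting explicit, here's a hypothetical example with two sample sets (the counts are made up):

```python
import numpy as np

# Rows are haplotype counts, columns are sample sets.
X = np.array([[4.0, 3.0],
              [2.0, 1.0],
              [2.0, 1.0]])
n = np.array([10.0, 6.0])  # sample set sizes

# n broadcasts across the rows: each column of X is divided
# by its own sample set's size in a single operation.
pAB, pAb, paB = X / n
```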

Finally, since we're no longer controlling the summary functions ourselves, the user has control over polarisation (see polarised parameter).

Normalisation

Perhaps the clunkiest thing about this API is that the user will need a normalisation function for multiallelic data (the norm_f parameter). Its function signature is the same as in the code internal to ld_matrix(): f(X, n, nA, nB) -> [out_dims], where X and n are as above and nA and nB are the number of A and B alleles, respectively. Its output dimensions must match the summary function's (we validate this just as we validate the summary function's output dims).

The default normalisation function of two_locus_count_stat is the total (or uniform) normalisation: we average all results into one. Here's how I have defined it in the function signature:

lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0),

I would have loved to be able to pass np.mean as a summary function instead, but we need haplotype counts to normalise $r^2$, and this keeps our original ld_matrix implementation untouched (we use the exact same machinery for this method, so we have to follow the C API).

sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows,
row_positions, out_cols, col_positions, options, result);
}
out:
Contributor

I don't know why codecov says this line is not covered; do you? (sticking some prints or asserts in here can help track that down)


Contributor Author

Judging from the coverage of the rest of that function (every goto being hit), my best guess is that the compiler is optimizing that label out. In fact, I could just change all of those goto out statements to return ret and remove the out: label, since we're not doing any cleanup at the end of this function.

Contributor

Don't do that; it's nice to have that pattern in the code.

Contributor

Oh! I see why this is complaining here. It's because right before this line is if (A) { ... } else if (B) { ... }, but there are no tests in which both A and B are false. And in fact it's impossible for this to happen. So the "else if ( )" should be an "else"; and if you want to be extra paranoid stick a tsk_bug_assert(B) in the code at the top of the else { }.

Every time I think "oh no codecov has this one wrong" it turns out that it's got a good point.

Contributor Author

I ended up removing the branch altogether and added tsk_bug_assert(mode_branch)

@petrelharp
Contributor

I've looked at the code pretty carefully, and things look good besides the comments. Nice work!

@lkirk
Contributor Author

lkirk commented Mar 14, 2026

Hi @petrelharp, I appreciate you taking a look. I've responded to everything with how I plan to resolve the comment. I did have one clarifying question here. I will clean this up tomorrow.

@petrelharp
Contributor

Hi, @lkirk - I went ahead and made a few adjustments to how the examples are chosen for testing; if you see something wrong or disagree, feel free to revert. Rationale:

  • Instead of listing "don't run these tests since they're slow" it's now "do run these tests" because otherwise if someone adds a new slow example to get_example_tree_sequence it won't slow things down mysteriously
  • Instead of pulling that one test case out I just call the function that generates it. This makes it a lot more clear what's going on.

@petrelharp
Contributor

Hm, I also changed the tests that pulled out "all_fields" to explicitly call that function. Now I see that we have a fixture defined for that one. Do you mind removing that function call and using the fixture instead? So that just means doing

def test_general_one_way_two_locus_stat_multiallelic(stat, ts_fixture):

and then using ts instead of ts_fixture in the code (or renaming it).

I'd do this but don't want to risk making conflicting edits.

@jeromekelleher
Member

I'm happy to look over this at some stage - please ping me when you'd like input!

sample_set_sizes,
sample_sets,
f,
norm_f or (lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]),
Contributor

Here let's make this more explicit (and hence easier to read): earlier do

if norm_f is None:
     norm_f = lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]

Contributor

Also note that in the description you wrote

lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0),

which looks more idiomatic, but I don't know?
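For what it's worth, the two spellings agree when nA and nB come in as numpy arrays (a quick check with made-up 0-d array values; with plain Python ints, the `[np.newaxis,]` indexing would raise a TypeError):

```python
import numpy as np

nA = np.array(3.0)  # stand-ins for whatever the C layer passes through
nB = np.array(4.0)

a = 1 / (nA * nB)[np.newaxis,]             # indexing binds before the division
b = np.expand_dims(1 / (nA * nB), axis=0)  # the spelling from the description
```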

@petrelharp
Contributor

Ah, just noticed your expanded description. So, to answer your question: the API looks good. I may have some other comments now, but will wait for you to ping me when it's ready to read carefully (including the docstring).

@lkirk
Contributor Author

lkirk commented Mar 16, 2026

thanks @petrelharp, glad to know it seems reasonable. I'm currently running down a bug that appears at first glance to be macOS-specific (it shook out with my updates). I'm hoping to be done with this in the next few hours; I'll summarize the changes and ping you when ready. For now, you might see a few more spurious pushes until I figure out what's going on. I don't have any way to test this locally on a mac.

col_positions,
mode,
)
if result_dim == 1: # drop dimension
Contributor

I don't think this dropping dimension depending on the output of f is consistent with how ts.sample_count_stat works: for instance,

>>> import tskit, msprime
>>> ts = msprime.sim_ancestry(2, sequence_length=10, recombination_rate=0.1)
>>> mts = msprime.sim_mutations(ts, rate=0.1)
>>> mts.sample_count_stat([mts.samples()], lambda x: [2, 3], strict=False, output_dim=2)
array([3.4, 5.1])
>>> mts.sample_count_stat([mts.samples()], lambda x: [2, 3], strict=False, output_dim=2, windows=[0, 3, 10])
array([[2.66666667, 4.        ],
       [3.71428571, 5.57142857]])
>>> mts.sample_count_stat([mts.samples()], lambda x: [2], strict=False, output_dim=1, windows=[0, 3, 10])
array([[2.66666667],
       [3.71428571]])

Contributor

But correct me if I'm wrong!

Contributor

I could be wrong, but it looks to me like to keep consistency, we would want

ts.two_locus_count_stat([0, 1, 2, 3])

to be the same thing as

ts.two_locus_count_stat([[0, 1, 2, 3]])  # <- note now a 2d array

except that the first case has dimension dropped.

However this makes no sense here because the output shape is decoupled from the sample sets, thanks to passing through the summary function. (Right?)

return result.reshape(result.shape[:2])
# Orient the data so that the first dimension is the sample set so that
# we get one LD matrix per sample set.
return result.swapaxes(0, 2).swapaxes(1, 2)
Contributor

I'm confused about this: since the summary function takes in all the sample sets at once, it doesn't have any notion of "one LD matrix per sample set"?

Contributor

(and if this is necessary, note that np.moveaxis would do this in a single operation)
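A quick check that the two spellings agree, using a made-up array with the (rows, cols, sample_sets) shape the snippet suggests:

```python
import numpy as np

# Hypothetical (rows, cols, sample_sets) result array.
result = np.arange(2 * 3 * 4).reshape(2, 3, 4)

# The two-step version...
a = result.swapaxes(0, 2).swapaxes(1, 2)
# ...versus moving axis 2 (sample sets) to the front in one call.
b = np.moveaxis(result, 2, 0)
```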

lkirk and others added 7 commits March 16, 2026 21:55
*Python tests*
Overhaul python testing of the general stat functions. Remove the
dependence on the example tree sequences, opting instead to simulate a
couple of examples directly. Use these simulated trees in test fixtures,
scoped at the module level. This streamlines the test parameterization a
lot.

Use the single stat site names from the summary function definitions.

*CPython tests*
Add a multiallelic tree sequence to test normalisation function
validation and errors. Remove one more occurrence of `np.expand_dims`.

*trees.c*
Remove the unnecessary branch in
tsk_treeseq_two_locus_count_general_stat, improving the code coverage.

*trees.py*
Default normalisation function can be None, applying default at runtime.
Simplifies calling code and is more in line with the rest of the API.
lkirk force-pushed the two-locus-general-matrix branch from ed3f0fb to bd0a1a5 on March 17, 2026 at 02:57